Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DPO cleanup #1126

Merged
merged 23 commits into from
Jan 23, 2024
Merged

DPO cleanup #1126

merged 23 commits into from
Jan 23, 2024

Conversation

winglian
Copy link
Collaborator

Description

This PR cleans up some hardcoding, improves the integration with trl's DPOTrainer and adds support for dpo prompt_strategies.

src/axolotl/utils/data.py Outdated Show resolved Hide resolved
Copy link
Contributor

@plaguss plaguss left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome PR! I left a comment in case you see fit. Also, maybe it could be tackled in a different PR, but the preprocess command could also be updated to allow checking rl datasets:

+    if parsed_cfg.rl:
+        _ = load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
+    else:
+        _ = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
-    _ = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)

src/axolotl/utils/data.py Outdated Show resolved Hide resolved
@winglian winglian force-pushed the dpo-cleanup branch 2 times, most recently from d5f97c3 to c0a1553 Compare January 23, 2024 02:21

def load(strategy, cfg):
try:
load_fn = strategy.split(".")[-1]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is most likely not correct. The strategy includes underscores, not ., such as intel_apply_chatml.

Copy link
Contributor

@filippo82 filippo82 Jan 23, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

def load(strategy, cfg):
    try:
        load_fn = strategy.split("_")[-1]
        #strategy = ".".join(strategy.split("_")[:-1])
        LOG.info(load_fn)
        LOG.info(strategy)
        mod = importlib.import_module(f".{load_fn}", "axolotl.prompt_strategies.dpo")
        func = getattr(mod, strategy)
        load_kwargs = {}
        return func(cfg, **load_kwargs)
    except Exception as e:  # pylint: disable=broad-exception-caught
        LOG.warning(e)
        return None

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This works for me.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the intention is the setting is something like

type: chatml.argilla

in which case it will load the argilla function from the axolotl.prompt_strategies.dpo.chatml module.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @winglian 👋🏻 thanks. That makes sense. I will test it later today 👍🏻

@winglian winglian merged commit 7523d1f into main Jan 23, 2024
1 of 6 checks passed
@winglian winglian deleted the dpo-cleanup branch January 23, 2024 05:40
djsaunde pushed a commit that referenced this pull request Dec 17, 2024
* cleanup dpo to be a little more extensible, add zephyr/nectar strategy

* fix eos slash

* support for eval split

* fix kwargs

* handle empty evals

* don't load peft model for dpo

* ensure dpo traning args gets bf16 for peft if applicable

* fix duplicate kwargs for bf16

* make sure to respect the configured lr scheduler

* supprt trainer callback to push config to wandb

* set dataloader preload args

* ensure that we are loading the lora when merging

* Update src/axolotl/utils/data.py

Co-authored-by: Agus <[email protected]>

* support local datasets for dpo

Co-authored-by: Agus <[email protected]>

* chore: lint

* dpo/kto/ipo smoke tests w lora, simplify dpo dataset type names

* add split to dpo tests

* fix rebase/merging error

* handle edge case w logging

* use accelerator for dpo datasets so it doesn't break the logger

* missing args

* validate checkpoint is an adapter for now

* log warning when dataset strategy is not loadable

---------

Co-authored-by: Agus <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants